import os
import traceback
import urllib
import urllib.parse
import urllib.request
from io import BytesIO

import cn_clip
import torch
from PIL import Image


def image_open(file_path):
    """Open an image from a local path or an http(s) URL, converted to RGB.

    Args:
        file_path: Local filesystem path (``str`` or ``os.PathLike``) or an
            ``http``/``https`` URL string.

    Returns:
        The image converted to RGB. If anything goes wrong while opening it,
        the error is reported by ``_handle_exception`` and its placeholder
        image is returned instead of propagating the exception.
    """
    try:
        # urlparse() cannot handle pathlib.Path objects; without this
        # coercion a Path argument would fail here and silently fall back
        # to the placeholder image.
        if isinstance(file_path, os.PathLike):
            file_path = os.fspath(file_path)
        if urllib.parse.urlparse(file_path).scheme in ("http", "https"):
            return _open_image_from_url(file_path)
        return _open_image(file_path)
    except Exception:
        # Deliberate best-effort: report the failure and substitute a
        # placeholder so batch pipelines keep running.
        return _handle_exception(file_path)


def get_embedding(input_type, input_value, model_infos, batch_mode=False):
    """Compute normalised embedding(s) for an image or text input.

    Args:
        input_type: ``"image"`` selects the image path; any other value is
            treated as text.
        input_value: A single image/text, or an iterable of them when
            ``batch_mode`` is true.
        model_infos: Mapping with ``"preprocess"``, ``"device"`` and
            ``"model"`` entries.
        batch_mode: Whether ``input_value`` holds several inputs.

    Returns:
        Python lists of floats. Note the asymmetry: the image path always
        returns one list per image (nested even for a single image), while
        the text path returns a flat list when ``batch_mode`` is false.
    """
    preprocess, device, model = (
        model_infos[key] for key in ("preprocess", "device", "model")
    )

    if input_type != "image":
        tokens = _preprocess_text(input_value, preprocess, device, batch_mode)
        return _encode_text(model, tokens, batch_mode)

    pixels = _preprocess_image(input_value, preprocess, device, batch_mode)
    return _encode_image(model, pixels)


def _open_image_from_url(file_path, timeout=None):
    """Download an image over HTTP(S) and return it converted to RGB.

    Args:
        file_path: The image URL.
        timeout: Optional socket timeout in seconds. When omitted, urlopen's
            own default is used (unchanged from previous behaviour).

    Returns:
        The downloaded image as an RGB PIL image.
    """
    # Only forward timeout when given so the library default is preserved.
    extra = {} if timeout is None else {"timeout": timeout}
    # Close the HTTP response deterministically instead of leaking it.
    with urllib.request.urlopen(file_path, **extra) as response:
        image_data = response.read()
    with Image.open(BytesIO(image_data)) as img:
        return img.convert("RGB")


def _open_image(file_path):
    """Open a local image file and return it converted to RGB.

    The file is opened in a ``with`` block so its handle is released as
    soon as the RGB copy has been produced, instead of relying on Pillow
    to close it later.
    """
    with Image.open(file_path) as img:
        return img.convert("RGB")


def _handle_exception(file_path):
    """Report the exception currently being handled and return a placeholder.

    Must be called from inside an ``except`` block so that
    ``traceback.format_exc()`` has an active exception to format.

    Returns:
        A 100x100 solid-black RGB image.
    """
    details = traceback.format_exc()
    print(f"{file_path}: {details}")
    black = (0, 0, 0)
    return Image.new('RGB', (100, 100), black)


def _preprocess_image(input_value, preprocess, device, batch_mode):
    if batch_mode:
        image = [preprocess(x).to(device) for x in input_value]
        image = torch.stack(image).to(device)
    else:
        image = preprocess(input_value).unsqueeze(0).to(device)
    return image


def _preprocess_text(input_value, preprocess, device, batch_mode):
    """Tokenize text with CN-CLIP's tokenizer and move the tokens to ``device``.

    ``preprocess`` is accepted only for signature symmetry with
    ``_preprocess_image``; it is not used for text.

    Args:
        input_value: A single string, or a list of strings in batch mode.
        device: Target device passed to ``Tensor.to``.
        batch_mode: Whether ``input_value`` already holds several strings.
    """
    texts = input_value if batch_mode else [input_value]
    return cn_clip.clip.tokenize(texts).to(device)


def _encode_image(model, image):
    with torch.no_grad():
        image_features = model.encode_image(image)

    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.tolist()


def _encode_text(model, text_tokens, batch_mode):
    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features.tolist() if batch_mode else text_features.tolist()[0]
